{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quantizing Llama 7B to W4A16 Using SparseML's OneShot Pathway\n",
    "\n",
    "This example notebook walks through how to quantize Llama 7B using SparseML. We apply int4 channel-wise quantization to all Linear layers except `lm_head`, using UltraChat 200k as the calibration dataset.\n",
    "\n",
    "This example requires at least 45GB of GPU memory to run. The memory requirement can be reduced to 32GB by setting `sequential_update: true` in the recipe definition, but this will increase the runtime significantly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from sparseml.transformers import SparseAutoModelForCausalLM, oneshot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SparseML uses recipes to define configurations for different oneshot algorithms. A recipe can be defined as a string or as a YAML file, and consists of one or more sparsification or quantization algorithms, called modifiers in SparseML. Below we create a sample recipe for GPTQ quantization that requires only a single modifier.\n",
    "\n",
    "This modifier specifies that the weights of each Linear layer are quantized to 4 bits using a symmetric, channel-wise quantization pattern. The `lm_head` layer is not quantized even though it is a Linear layer, because it is included in the `ignore` list."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "recipe = \"\"\"\n",
    "quant_stage:\n",
    "    quant_modifiers:\n",
    "        GPTQModifier:\n",
    "            sequential_update: false\n",
    "            ignore: [\"lm_head\"]\n",
    "            config_groups:\n",
    "                group_0:\n",
    "                    weights:\n",
    "                        num_bits: 4\n",
    "                        type: \"int\"\n",
    "                        symmetric: true\n",
    "                        strategy: \"channel\"\n",
    "                    targets: [\"Linear\"]\n",
    "\"\"\""
   ]
  },
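  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the `weights` settings above concrete, here is a minimal sketch of symmetric, channel-wise int4 quantization using plain round-to-nearest on a toy tensor. It is not the GPTQ algorithm that `GPTQModifier` runs, which additionally uses calibration data to compensate for rounding error. The helper name `quantize_symmetric_per_channel` and the max-abs choice of scale are assumptions made for illustration, not SparseML's implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def quantize_symmetric_per_channel(weight, num_bits=4):\n",
    "    # hypothetical helper: round-to-nearest with one scale per output channel (row)\n",
    "    qmax = 2 ** (num_bits - 1) - 1  # 7 for int4\n",
    "    qmin = -(2 ** (num_bits - 1))  # -8 for int4\n",
    "    # symmetric: no zero point, the largest magnitude in each row maps to qmax\n",
    "    max_abs = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)\n",
    "    scale = max_abs / qmax\n",
    "    q = torch.clamp(torch.round(weight / scale), qmin, qmax).to(torch.int8)\n",
    "    return q, scale\n",
    "\n",
    "\n",
    "torch.manual_seed(0)\n",
    "toy_weight = torch.randn(4, 16)  # toy stand-in for a Linear weight with 4 output channels\n",
    "q, scale = quantize_symmetric_per_channel(toy_weight)\n",
    "dequantized = q.float() * scale  # the values a fake-quantized layer would use\n",
    "\n",
    "print('integer range used:', q.min().item(), q.max().item())\n",
    "print('max absolute rounding error:', (toy_weight - dequantized).abs().max().item())"
   ]
  },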
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we need to initialize the model we wish to quantize and define a dataset for calibration. We use a Llama 2 7B model that has been pretrained on the UltraChat 200k dataset, and we calibrate on that same dataset.\n",
    "\n",
    "SparseML supports several datasets, such as ultrachat-200k, out of the box. You can also pass in a tokenized `datasets.Dataset` object for custom dataset support."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# by setting the device_map to auto, we can spread the model evenly across all available GPUs\n",
    "# load the model in as bfloat16 to save on memory and compute\n",
    "model_stub = \"zoo:llama2-7b-ultrachat200k_llama2_pretrain-base\"\n",
    "model = SparseAutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.bfloat16, device_map=\"auto\")\n",
    "\n",
    "# uses SparseML's built-in preprocessing for ultra chat\n",
    "dataset = \"ultrachat-200k\"\n",
    "\n",
    "# save location of quantized model\n",
    "output_dir = \"./output_llama7b_W4A16_channel_compressed\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we configure the calibration dataset. To save on load time, we load only the first 5% of UltraChat 200k's `train_gen` split and label it as calibration data. Oneshot calibration does not require padded inputs, so we set `pad_to_max_length` to False. We also truncate each sample to a maximum of 512 tokens and select 512 samples for calibration.\n",
    "\n",
    "Using more calibration samples can improve model performance but takes longer to run. Generally, 256-2048 calibration samples are recommended."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set dataset config parameters\n",
    "splits = {\"calibration\": \"train_gen[:5%]\"}\n",
    "max_seq_length = 512\n",
    "pad_to_max_length = False\n",
    "num_calibration_samples = 512"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can launch our quantization recipe using the `oneshot` function. This call applies the algorithms defined in `recipe` to the input `model`, using `num_calibration_samples` samples from `dataset` as calibration data, and saves the quantized model to `output_dir`.\n",
    "\n",
    "Setting `save_compressed` to True saves the model with every 8 int4 weights packed into a single int32, which allows the model to be loaded by vLLM. Once a model has been saved in this way, the original unquantized weights can no longer be recovered. To save the model in a \"fake quantized\" state instead, so that the original weights are preserved, set `save_compressed` to False."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "oneshot(\n",
    "    model=model,\n",
    "    dataset=dataset,\n",
    "    recipe=recipe,\n",
    "    output_dir=output_dir,\n",
    "    splits=splits,\n",
    "    max_seq_length=max_seq_length,\n",
    "    pad_to_max_length=pad_to_max_length,\n",
    "    num_calibration_samples=num_calibration_samples,\n",
    "    save_compressed=True\n",
    ")"
   ]
  },
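  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the packing described above, the sketch below stores eight signed int4 values in one int32 and unpacks them again. The helper names `pack_int4` and `unpack_int4` and the least-significant-nibble-first ordering are assumptions for illustration; the actual on-disk layout written by SparseML may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def pack_int4(values):\n",
    "    # hypothetical helper: values is a 1-D integer tensor in [-8, 7], length divisible by 8\n",
    "    nibbles = (values.to(torch.int64) & 0xF).reshape(-1, 8)  # keep the low 4 bits (two's complement)\n",
    "    shifts = torch.arange(8, dtype=torch.int64) * 4  # assumed order: first value in the lowest nibble\n",
    "    packed = (nibbles << shifts).sum(dim=1)  # nibbles do not overlap, so sum acts as bitwise OR\n",
    "    packed = torch.where(packed >= 2**31, packed - 2**32, packed)  # reinterpret as signed 32-bit\n",
    "    return packed.to(torch.int32)\n",
    "\n",
    "\n",
    "def unpack_int4(packed):\n",
    "    # inverse of pack_int4: recover eight signed 4-bit values from each int32\n",
    "    unsigned = packed.to(torch.int64) & 0xFFFFFFFF\n",
    "    shifts = torch.arange(8, dtype=torch.int64) * 4\n",
    "    nibbles = (unsigned.unsqueeze(1) >> shifts) & 0xF\n",
    "    signed = torch.where(nibbles >= 8, nibbles - 16, nibbles)  # sign-extend each nibble\n",
    "    return signed.reshape(-1).to(torch.int8)\n",
    "\n",
    "\n",
    "toy_values = torch.tensor([-8, -1, 0, 7, 3, -4, 5, -2], dtype=torch.int8)\n",
    "toy_packed = pack_int4(toy_values)\n",
    "assert toy_packed.dtype == torch.int32 and toy_packed.numel() == 1\n",
    "assert torch.equal(unpack_int4(toy_packed), toy_values)  # the round trip is lossless"
   ]
  },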
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The quantized model should now be stored in the defined `output_dir`. Its `config.json` will contain a new `compression_config` field that describes how the model has been quantized. This config will be used to load the model into vLLM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir"
   ]
  },
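  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to confirm the result is to read `config.json` from `output_dir` and print its `compression_config` entry. This sketch assumes the `oneshot` cell above has already run; the keys inside `compression_config` are not listed here because they may depend on the SparseML version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "# output_dir is defined earlier in this notebook\n",
    "config_path = os.path.join(output_dir, 'config.json')\n",
    "\n",
    "if os.path.exists(config_path):\n",
    "    with open(config_path) as f:\n",
    "        config = json.load(f)\n",
    "    # prints null if the field is missing, e.g. when the model was not saved by oneshot\n",
    "    print(json.dumps(config.get('compression_config'), indent=2))\n",
    "else:\n",
    "    print(f'No config.json found at {config_path}; run the oneshot cell first.')"
   ]
  },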
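  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# optionally save a second compressed copy of the quantized model under a different directory name\n",
    "model.save_pretrained(\"llama7b_W4A16_channel_packed\", save_compressed=True)"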
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
